"""train_model.py
   Train and save models
   Developed as part of Recur project
   November 2020
"""

import argparse
import os
import sys
from collections import OrderedDict

import numpy as np
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter

import warmup
from learning_module import TrainingSetup, TestingSetup, train, test, OptimizerWithSched
from utils import load_model_from_checkpoint, get_dataloaders, to_json, get_optimizer
from utils import to_log_file, now, get_model


# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), Bare except (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115


def _checkpoint_path(args, mode, task_id):
    """Return the checkpoint file path used both for saving and for ``--test_only`` loading.

    The epoch tag is always ``args.epochs - 1`` (not the current epoch), so every periodic
    save within one run writes to -- and overwrites -- the same file.
    """
    return os.path.join(args.checkpoint,
                        f"{args.model}_{args.dataset}_{args.optimizer}"
                        f"_batchsize={args.train_batch_size}"
                        f"_epoch={args.epochs-1}"
                        f"_mode={mode}"
                        f"_{task_id}.pth")


def _select_train_test(mode, problem):
    """Return the ``(train, test)`` pair from learning_module for this mode/problem.

    Args:
        mode: the ``--mode`` value (None or an int).
        problem: lower-cased ``--problem`` string.

    Raises:
        ValueError: if a mode is given for a problem that has no mode-aware pair.
    """
    if mode is not None:
        if problem == "classification":
            from learning_module import train_with_modes, test_with_modes
            return train_with_modes, test_with_modes
        if problem == "segment":
            from learning_module import train_segment_with_modes, test_segment_with_modes
            return train_segment_with_modes, test_segment_with_modes
        raise ValueError(f"--mode is only supported for classification and segment problems, "
                         f"got problem={problem!r}")
    if problem == "segment":
        from learning_module import train_segment, test_segment
        return train_segment, test_segment
    # Generic pair imported at module level.
    return train, test


def main():
    """Parse CLI arguments, train (or load) a model, then evaluate it.

    Side effects: writes TensorBoard events under ``<output>/runs``, optionally records the
    arguments and periodic stats via ``to_log_file`` and the final stats via ``to_json``,
    and saves checkpoints under ``args.checkpoint``.
    """
    print("\n_________________________________________________\n")
    print(now(), "train_model.py main() running.")

    parser = argparse.ArgumentParser(description="Deep Thinking")
    parser.add_argument("--lr", default=0.1, type=float, help="Classifier learning rate")
    parser.add_argument("--lr_factor", default=0.1, type=float, help="Learning rate decay factor")
    parser.add_argument('--lr_schedule', nargs='+', default=[100, 150], type=int,
                        help='how often to decrease lr')
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs for training")
    parser.add_argument("--train_batch_size", default=128, type=int,
                        help="batch size for training")
    parser.add_argument("--test_batch_size", default=50, type=int,
                        help="batch size for testing")
    parser.add_argument("--test_iterations", default=None, type=int,
                        help="if testing with different iterations than training, use this arg.")
    parser.add_argument("--optimizer", default="SGD", type=str, help="optimizer")
    parser.add_argument("--model", default="resnet18", type=str, help="model for training")
    parser.add_argument("--model_file", default="recur_resnet", type=str,
                        help="model for training")
    parser.add_argument("--dataset", default="CIFAR10", type=str, help="dataset")
    parser.add_argument("--val_period", default=20, type=int, help="print every __ epoch")
    parser.add_argument("--save_period", default=None, type=int, help="print every __ epoch")
    parser.add_argument("--output", default="output_default", type=str, help="output subdirectory")
    parser.add_argument("--checkpoint", default="check_default", type=str,
                        help="where to save the network")
    parser.add_argument("--model_path", default=None, type=str, help="where is the model saved?")
    parser.add_argument("--save_json", action="store_true", help="save json?")
    parser.add_argument("--no_save_log", action="store_true", help="save log file?")
    parser.add_argument("--train_log", default="train_log.txt", type=str,
                        help="name of the log file")
    parser.add_argument("--tol", default=1.0, type=float, help="regression error tolerance")
    parser.add_argument("--maze_size", default=7, type=int, help="size of the actual maze")
    parser.add_argument("--problem", default="classification", type=str,
                        help="Type of problem: classification or regression")
    parser.add_argument("--mode", default=None, type=int, help="which experimental learning mode?")
    parser.add_argument("--test_only", action="store_true", help="just wanna test?")
    parser.add_argument("--test_dataset", type=str, default=None, help="name of the testing dataset")
    args = parser.parse_args()

    if args.save_period is None:
        args.save_period = args.epochs
    print(args)

    args.checkpoint = os.path.join(args.checkpoint,
                                   f'lr={args.lr}_batchsize={args.train_batch_size}_'
                                   f'maze={args.maze_size}_tol={args.tol}')

    # The task id is the last "_"-separated token of the log file name with its last four
    # characters (e.g. ".txt") dropped; it is appended to checkpoint file names.
    train_log = args.train_log
    array_task_id = train_log[:-4].split('_')[-1]

    if args.test_only:
        # Load the final checkpoint of an earlier run saved with mode=None, evaluate it on
        # --test_dataset (or the training dataset if none is given), and skip training.
        args.model_path = _checkpoint_path(args, None, array_task_id)
        if args.test_dataset is not None:
            args.dataset = args.test_dataset
        args.epochs = 0

    writer = SummaryWriter(log_dir=f"{args.output}/runs/{train_log[:-4]}")

    if not args.no_save_log:
        to_log_file(args, args.output, train_log)

    # set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset and Network and Optimizer
    trainloader, testloader = get_dataloaders(args.dataset, args.train_batch_size,
                                              test_batch_size=args.test_batch_size)

    # Load the network from a checkpoint if a path is provided; the optimizer state it
    # returns is restored after the optimizer is built below.
    if args.model_path is not None:
        print(f"Loading model from checkpoint {args.model_path}...")
        net, start_epoch, optimizer_state_dict = load_model_from_checkpoint(args.model,
                                                                            args.model_path,
                                                                            args.model_file,
                                                                            args.dataset)
        print(start_epoch)
        # Resume from the epoch after the one stored in the checkpoint.
        start_epoch += 1
    else:
        net = get_model(args.model, args.model_file, args.dataset)
        start_epoch = 0
        optimizer_state_dict = None

    net = net.to(device)
    print(net)
    pytorch_total_params = sum(p.numel() for p in net.parameters())
    print(f"This {args.model} has {pytorch_total_params/1e6:0.3f} million parameters")

    optimizer = get_optimizer(args.optimizer, args.model, net, args.lr)

    if optimizer_state_dict is not None:
        print(f"Loading optimizer from checkpoint {args.model_path}...")
        optimizer.load_state_dict(optimizer_state_dict)
        # No warm-up when continuing from a restored optimizer state.
        warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=0)
    else:
        warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=5)

    lr_scheduler = MultiStepLR(optimizer, milestones=args.lr_schedule, gamma=args.lr_factor,
                               last_epoch=-1)
    optimizer_obj = OptimizerWithSched(optimizer, lr_scheduler, warmup_scheduler)
    ####################################################

    ####################################################
    #        Train and Test
    np.set_printoptions(precision=2)
    torch.backends.cudnn.benchmark = True

    problem = args.problem.lower()
    train_setup = TrainingSetup(args.model, args.tol, problem, args.mode)
    test_setup = TestingSetup(args.model, args.tol, problem, args.mode)

    train_fn, test_fn = _select_train_test(args.mode, problem)

    print(f"==> Starting training for {args.epochs - start_epoch} epochs...")

    # Training-set accuracy from the most recent validation pass (None if none ran); the
    # final mode None/0 report uses it instead of re-evaluating the training set.
    train_acc = None

    for epoch in range(start_epoch, args.epochs):

        loss, acc = train_fn(net, trainloader, optimizer_obj, train_setup, device)
        validate = (epoch + 1) % args.val_period == 0
        log_stats = None

        if args.mode is None or args.mode == 0:
            print(f"{now()} Training loss at epoch {epoch}: {loss}")
            print(f"{now()} Training accuracy at epoch {epoch}: {acc}")

            # If the loss is nan and is not recovering then stop the training
            if np.isnan(float(loss)):
                print("Loss is nan, exiting...")
                writer.close()
                sys.exit()

            # Tensorboard loss writing
            writer.add_scalar("Loss/loss", loss, epoch)
            writer.add_scalar("Accuracy/acc", acc, epoch)

            for i, group in enumerate(optimizer.param_groups):
                writer.add_scalar(f"Learning_rate/group{i}", group['lr'], epoch)

            if validate:
                print(f"{now()} Epoch: {epoch}")
                print(f"{now()} Loss: {loss}")
                print(f"{now()} Training acc: {acc}")

                train_acc = test_fn(net, trainloader, test_setup, device)
                test_acc = test_fn(net, testloader, test_setup, device)

                print(f"{now()} Training accuracy: {train_acc}")
                print(f"{now()} Testing accuracy: {test_acc}")

                # TensorBoard tags always use "/" regardless of OS path separator.
                writer.add_scalar("val/train_acc", train_acc, epoch)
                writer.add_scalar("val/test_acc", test_acc, epoch)

                log_stats = OrderedDict([("epoch", epoch),
                                         ("loss", loss),
                                         ("training_acc", acc),
                                         ("test_acc", test_acc)])

        elif args.mode == 1:
            # In mode 1, train returns (total_loss, thinker_loss) and a sequence of accuracies.
            print(f"{now()} Training loss at epoch {epoch}: {loss[0]}")
            print(f"{now()} Thinker loss at epoch {epoch}: {loss[1]}")
            print(f"{now()} Training accuracy at epoch {epoch}: {acc}")

            # If the loss is nan and is not recovering then stop the training
            if np.isnan(float(loss[0])):
                print("Loss is nan, exiting...")
                writer.close()
                sys.exit()

            # Tensorboard loss writing
            writer.add_scalar("Loss/total_loss", loss[0], epoch)
            writer.add_scalar("Loss/thinker_loss", loss[1], epoch)
            writer.add_scalars("Accuracy/acc", {str(i): a for i, a in enumerate(acc)}, epoch)

            if validate:
                print(f"{now()} Epoch: {epoch}")
                print(f"{now()} Loss: {loss[0]}")
                print(f"{now()} Thinker Loss: {loss[1]}")
                print(f"{now()} Training acc: {acc}")

                # test returns a sequence; indices 0, 2 and 3 are accuracy,
                # accumulated accuracy and accuracy on agreement respectively.
                train_stats = test_fn(net, trainloader, test_setup, device)
                train_acc = train_stats[0]
                train_acc_accum = train_stats[2]
                train_acc_agreement = train_stats[3]

                test_stats = test_fn(net, testloader, test_setup, device)
                test_acc = test_stats[0]
                test_acc_accum = test_stats[2]
                test_acc_agreement = test_stats[3]

                print(f"{now()} Training accuracy: {train_acc}")
                print(f"{now()} Training acc on agreement: {train_acc_agreement}")
                print(f"{now()} Training acc accumulated: {train_acc_accum}")
                print(f"{now()} Testing accuracy: {test_acc}")
                print(f"{now()} Testing acc on agreement: {test_acc_agreement}")
                print(f"{now()} Testing acc accumulated: {test_acc_accum}")

                for stat_name, stat in (("train_bound", train_acc_accum),
                                        ("test_bound", test_acc_accum)):
                    writer.add_scalars(f"val/{stat_name}",
                                       {str(i): value for i, value in enumerate(stat)}, epoch)

                log_stats = OrderedDict([("epoch", epoch),
                                         ("loss", loss),
                                         ("training_acc", acc),
                                         ("test_acc", test_acc),
                                         ("Training_acc_accumulated", train_acc_accum),
                                         ("Testing_acc_accumulated", test_acc_accum),
                                         ("Training_acc_agreement", train_acc_agreement),
                                         ("Testing_acc_agreement", test_acc_agreement)])

        # Periodic stats go to the log file for every mode that produced them.
        if log_stats is not None and not args.no_save_log:
            to_log_file(log_stats, args.output, train_log)

        if (epoch + 1) % args.save_period == 0 or (epoch + 1) % args.epochs == 0:
            state = {
                "net": net.state_dict(),
                "epoch": epoch,
                "optimizer": optimizer.state_dict()
            }
            out_str = _checkpoint_path(args, args.mode, array_task_id)

            print("saving model to: ", args.checkpoint, " out_str: ", out_str)
            os.makedirs(args.checkpoint, exist_ok=True)
            torch.save(state, out_str)

    writer.flush()
    writer.close()
    ####################################################

    ####################################################
    #        Test
    print("==> Starting testing...")

    if args.test_iterations is not None:
        net.iters = args.test_iterations

    stats = None
    if args.mode is None or args.mode == 0:
        # The training set is not re-evaluated here; train_acc is the last validation value.
        test_acc = test_fn(net, testloader, test_setup, device)

        print(f"{now()} Testing accuracy: {test_acc}")

        stats = OrderedDict([("model", args.model),
                             ("num_params", pytorch_total_params),
                             ("learning rate", args.lr),
                             ("lr_factor", args.lr_factor),
                             ("lr", args.lr),
                             ("epochs", args.epochs),
                             ("train_batch_size", args.train_batch_size),
                             ("optimizer", args.optimizer),
                             ("dataset", args.dataset),
                             ("train_acc", train_acc),
                             ("test_acc", test_acc),
                             ("test_iter", args.test_iterations)])

    elif args.mode == 1:
        train_stats = list(test_fn(net, trainloader, test_setup, device))
        test_stats = list(test_fn(net, testloader, test_setup, device))

        train_acc = list(train_stats[0])
        train_acc_accum = list(train_stats[2])
        train_acc_agreement = list(train_stats[3])
        test_acc = list(test_stats[0])
        test_acc_accum = list(test_stats[2])
        test_acc_agreement = list(test_stats[3])

        print(f"{now()} Training accuracy: {train_acc}")
        print(f"{now()} Training acc on agreement: {train_acc_agreement}")
        print(f"{now()} Training acc accumulated: {train_acc_accum}")
        print(f"{now()} Testing accuracy: {test_acc}")
        print(f"{now()} Testing acc on agreement: {test_acc_agreement}")
        print(f"{now()} Testing acc accumulated: {test_acc_accum}")

        stats = OrderedDict([("model", args.model),
                             ("learning rate", args.lr),
                             ("lr_factor", args.lr_factor),
                             ("num_params", pytorch_total_params),
                             ("epochs", args.epochs),
                             ("train_batch_size", args.train_batch_size),
                             ("mode", args.mode),
                             ("optimizer", args.optimizer),
                             ("dataset", args.dataset),
                             ("train_acc", train_acc),
                             ("test_acc", test_acc),
                             ("train_acc_accum", train_acc_accum),
                             ("test_acc_accum", test_acc_accum),
                             ("train_acc_agree", train_acc_agreement),
                             ("test_acc_agree", test_acc_agreement),
                             ("test_iter", args.test_iterations)
                             ])

    if stats is not None:
        if problem == "segment":
            stats["tol"] = args.tol
            stats["maze_size"] = args.maze_size

        if args.save_json:
            to_json(stats, args.output)
    elif args.save_json:
        print(f"No final stats are collected for mode {args.mode}; skipping JSON output.")
    ####################################################


if __name__ == "__main__":
    main()
